package com.hyjiacan.tools.apps.dbm.utils

import com.hyjiacan.tools.apps.dbm.models.DataTable
import com.hyjiacan.tools.apps.dbm.models.ParamItem
import com.hyjiacan.tools.apps.dbm.services.MysqlConnections
import com.hyjiacan.tools.http.entities.QueryResult
import com.hyjiacan.tools.utils.Json
import com.mysql.cj.jdbc.ClientPreparedStatement
import java.io.BufferedInputStream
import java.io.InputStream
import java.sql.*
import java.sql.Array
import java.time.format.DateTimeFormatter

object MySQL {
    /**
     * Opens a MySQL connection described by [conOpts], passes it to [callback],
     * and closes the connection afterwards whether or not the callback succeeded.
     *
     * The JDBC URL is built from the decoded options (host, port, database,
     * unicode flag, encoding, timezone, SSL flag) and always enables
     * `nullCatalogMeansCurrent` and `allowMultiQueries`.
     *
     * @param conOpts encoded connection options, decoded with `Misc.decodeOption`
     * @param callback invoked exactly once with the open connection
     * @throws Exception every failure (driver loading, option decoding, connecting,
     * or the callback itself) is printed and rethrown. A failure that carries no
     * usable message is wrapped in a generic "unknown error" exception that keeps
     * the original failure as its cause.
     */
    @Throws(SQLException::class)
    fun connect(conOpts: String?, callback: (Connection) -> Unit) {
        var conn: Connection? = null
        try {
            Class.forName("com.mysql.cj.jdbc.Driver")
            val connection = Misc.decodeOption(conOpts)

            // Build the connection URL: host, port, database name, character set, timezone and SSL.
            val url = String.format(
                "jdbc:mysql://%s:%d/%s?nullCatalogMeansCurrent=true&useUnicode=%s&characterEncoding=%s&allowMultiQueries=true&serverTimezone=%s&useSSL=%s",
                connection.host,
                connection.port,
                connection.database,
                connection.unicode,
                connection.encoding,
                connection.timezone,
                connection.useSSL,
            )

            conn = DriverManager.getConnection(
                url, connection.user, connection.password
            )

            callback(conn)
        } catch (e: Exception) {
            e.printStackTrace()
            if (e.message.isNullOrBlank()) {
                // Keep the original exception as the cause so its type and stack trace are not lost.
                throw Exception("未知错误", e)
            }
            throw e
        } finally {
            conn?.close()
        }
    }

    /**
     * Executes [statement] against the MySQL server described by [conOpts] and
     * collects the outcome into a [QueryResult].
     *
     * When [statement] is `null` the call is a pure connectivity probe: a fresh
     * connection is opened with the supplied credentials and closed again, and the
     * result reports whether that worked.
     *
     * Otherwise a connection is taken from `MysqlConnections` (keyed by the JDBC URL)
     * or created and stored there; it is deliberately left open after the call.
     * The statement is prepared with a scrollable read-only cursor, [args] are bound
     * positionally (1-based), and every result set produced is converted to a
     * [DataTable]. A statement that produces no result set yields a single
     * [DataTable] carrying the insert id and update count when the driver statement
     * is a `ClientPreparedStatement`.
     *
     * No exception escapes this function: any failure is printed and reported via
     * `success = false` and `message` on the returned object.
     *
     * @param conOpts encoded connection options, decoded with `Misc.decodeOption`
     * @param statement SQL text to run, or `null` to only test connectivity
     * @param fetchOffset number of leading rows to skip in each result set (negative values are treated as 0)
     * @param fetchSize maximum number of rows to read per result set; values `<= 0` read all rows
     * @param args positional parameters, in placeholder order; `null` means none
     * @return the populated [QueryResult]
     */
    @Throws(SQLException::class)
    fun execute(
        conOpts: String?,
        statement: String?,
        fetchOffset: Int,
        fetchSize: Int,
        args: List<ParamItem>? = listOf()
    ): QueryResult {
        val qr = QueryResult()
        var stmt: PreparedStatement? = null
        try {
            Class.forName("com.mysql.cj.jdbc.Driver")
            val connection = Misc.decodeOption(conOpts)

            // Build the connection URL: host, port, database name, character set, timezone and SSL.
            val url = String.format(
                "jdbc:mysql://%s:%d/%s?nullCatalogMeansCurrent=true&useUnicode=%s&characterEncoding=%s&allowMultiQueries=true&serverTimezone=%s&useSSL=%s",
                connection.host,
                connection.port,
                connection.database,
                connection.unicode,
                connection.encoding,
                connection.timezone,
                connection.useSSL,
            )
            // Applies to the probe below as well. DriverManager's login timeout is process-wide state.
            if (connection.timeout > 0) {
                DriverManager.setLoginTimeout(connection.timeout)
            }
            // No SQL to run: the caller only wants to know whether connecting works,
            // so actually connect (with the supplied credentials) instead of reporting success blindly.
            if (statement == null) {
                DriverManager.getConnection(url, connection.user, connection.password).close()
                qr.success = true
                return qr
            }
            // NOTE(review): the cache key is the URL only; user and password are not part of it,
            // so two credential sets for the same URL share one cached connection.
            // A cached connection that has already been closed is replaced by a new one.
            val conn: Connection = MysqlConnections.get(url)?.takeUnless { it.isClosed }
                ?: DriverManager.getConnection(url, connection.user, connection.password)
            MysqlConnections.set(url, conn)

            @Suppress("SqlSourceToSinkFlow")
            val prepared =
                conn.prepareStatement(statement, ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_READ_ONLY)
            stmt = prepared

            args?.forEachIndexed { idx, param -> bindParam(prepared, idx + 1, param) }

            val hasResultSet = prepared.execute()
            qr.success = true

            val dataset = mutableListOf<DataTable>()
            if (hasResultSet) {
                // NOTE(review): getMoreResults() is false for an update count as well as for
                // "no more results", so result sets that follow an update in a multi-statement
                // batch are not collected here.
                do {
                    val result = prepared.resultSet ?: break
                    dataset.add(readTable(result, fetchOffset, fetchSize))
                } while (prepared.moreResults)
            } else {
                val table = DataTable()
                if (prepared is ClientPreparedStatement) {
                    // Generated key of an INSERT.
                    table.insertId = prepared.lastInsertID
                    // Number of rows changed by the statement.
                    table.rowsAffected = prepared.updateCount
                }
                dataset.add(table)
            }
            qr.data = dataset
        } catch (e: Exception) {
            e.printStackTrace()
            qr.success = false
            qr.message = e.message
        } finally {
            // Only the statement is closed; the connection stays open in the cache.
            stmt?.close()
        }
        return qr
    }

    /**
     * Binds one positional parameter on [stmt] at 1-based [index], choosing the
     * setter from the JDBC type that `getSqlTypeByDbType` derives from `param.type`.
     * A `null` value is bound with `setNull`; unknown or textual types are bound as strings.
     * Binary types expect a hex string (optionally prefixed with `0x`).
     */
    private fun bindParam(stmt: PreparedStatement, index: Int, param: ParamItem) {
        val sqlType = getSqlTypeByDbType(param.type)
        val raw = param.value
        if (raw == null) {
            stmt.setNull(index, sqlType)
            return
        }

        val text = raw.toString()
        when (sqlType) {
            Types.ARRAY -> stmt.setArray(index, Json.parseArray(text) as Array)
            Types.BIGINT -> stmt.setLong(index, text.toLong())
            // Accept "true" (any case) as well as "1", so a boolean value is not bound as false.
            Types.BIT, Types.BOOLEAN ->
                stmt.setBoolean(index, text == "1" || text.equals("true", ignoreCase = true))
            Types.BINARY, Types.BLOB, Types.LONGVARBINARY, Types.VARBINARY ->
                stmt.setBytes(index, hex2bin(text))
            Types.DATE -> stmt.setDate(index, Date.valueOf(text))
            Types.DECIMAL, Types.NUMERIC -> stmt.setBigDecimal(index, text.toBigDecimal())
            Types.DOUBLE -> stmt.setDouble(index, text.toDouble())
            Types.FLOAT -> stmt.setFloat(index, text.toFloat())
            Types.INTEGER -> stmt.setInt(index, text.toInt())
            Types.SMALLINT -> stmt.setShort(index, text.toShort())
            Types.TIME -> stmt.setTime(index, Time.valueOf(text))
            Types.TIMESTAMP -> stmt.setTimestamp(index, Timestamp.valueOf(text))
            Types.TINYINT -> stmt.setByte(index, text.toByte())
            // CHAR, VARCHAR, LONGVARCHAR, NCHAR, NVARCHAR, LONGNVARCHAR and everything else.
            else -> stmt.setString(index, text)
        }
    }

    /**
     * Converts one scrollable [result] into a [DataTable]: column names plus the
     * table and catalog each column comes from, the total row count of the result
     * set, and up to [fetchSize] rows starting after the first [fetchOffset] rows.
     */
    private fun readTable(result: ResultSet, fetchOffset: Int, fetchSize: Int): DataTable {
        // Jump to the last row to learn the total row count, then position the cursor
        // on row `fetchOffset` (0 = before the first row) so next() yields the row after it.
        result.absolute(-1)
        val totalRows = result.row
        result.absolute(fetchOffset.coerceAtLeast(0))
        if (fetchSize > 0) {
            result.fetchSize = fetchSize
        }

        val meta = result.metaData
        val colCount = meta.columnCount
        val columns: MutableList<String> = ArrayList()
        val tables = arrayListOf<String>()
        val databases = arrayListOf<String>()
        for (col in 1..colCount) {
            columns.add(meta.getColumnName(col))
            tables.add(meta.getTableName(col))
            databases.add(meta.getCatalogName(col))
        }

        val rows: ArrayList<Any> = ArrayList()
        while (result.next()) {
            val row: MutableList<Any?> = ArrayList()
            for (col in 1..colCount) {
                row.add(result.getObject(col))
            }
            rows.add(row)
            if (fetchSize > 0 && rows.size == fetchSize) {
                break
            }
        }

        val dt = DataTable()
        dt.query = true
        dt.tables = tables
        dt.databases = databases
        dt.data = rows
        dt.columns = columns
        dt.rowsAffected = totalRows
        return dt
    }

    /**
     * Renders the cell in column [i] (0-based) of the current row of [result] as text.
     *
     * @param type JDBC type code (`java.sql.Types`) of the column, used to pick the getter
     * @param result result set positioned on the row to read
     * @param i 0-based column index; the JDBC column read is `i + 1`
     * @return the rendered value and a flag telling whether it is string-like
     * (text, dates, times, arrays and unknown types) as opposed to numeric,
     * boolean, binary or NULL. SQL NULL is rendered as `"NULL"` with the flag `false`;
     * binary values are rendered as `0x`-prefixed uppercase hex; booleans as `"1"`/`"0"`.
     */
    fun readCell(
        type: Int,
        result: ResultSet,
        i: Int
    ): Pair<String, Boolean> {
        val col = i + 1

        if (result.getObject(col) == null) {
            return Pair("NULL", false)
        }

        return when (type) {
            Types.ARRAY -> Pair(Json.stringify(result.getArray(col)), true)

            Types.BIGINT -> Pair(result.getLong(col).toString(), false)

            Types.BIT, Types.BOOLEAN -> Pair(if (result.getBoolean(col)) "1" else "0", false)

            Types.BINARY, Types.LONGVARBINARY, Types.VARBINARY ->
                Pair(bin2hex(result.getBinaryStream(col)), false)

            Types.BLOB -> Pair(bin2hex(result.getBlob(col).binaryStream), false)

            // Close the reader once its text has been consumed.
            Types.CHAR -> Pair(result.getCharacterStream(col).use { it.readText() }, true)

            Types.DATE ->
                Pair(result.getDate(col).toLocalDate().format(DateTimeFormatter.ISO_LOCAL_DATE), true)

            Types.DECIMAL, Types.NUMERIC -> Pair(result.getBigDecimal(col).toString(), false)

            Types.DOUBLE -> Pair(result.getDouble(col).toString(), false)

            // REAL is a single-precision numeric type too; treat it like FLOAT rather than as a string.
            Types.FLOAT, Types.REAL -> Pair(result.getFloat(col).toString(), false)

            Types.INTEGER -> Pair(result.getInt(col).toString(), false)

            Types.LONGNVARCHAR, Types.NVARCHAR -> Pair(result.getNString(col).toString(), true)

            Types.LONGVARCHAR, Types.VARCHAR -> Pair(result.getString(col).toString(), true)

            Types.NCHAR -> Pair(result.getNCharacterStream(col).use { it.readText() }, true)

            Types.NULL -> Pair("NULL", false)

            Types.SMALLINT -> Pair(result.getShort(col).toString(), false)

            // A LocalTime has no date fields, so it must be formatted with the time-only formatter;
            // ISO_LOCAL_DATE_TIME would throw UnsupportedTemporalTypeException here.
            Types.TIME ->
                Pair(result.getTime(col).toLocalTime().format(DateTimeFormatter.ISO_LOCAL_TIME), true)

            Types.TIMESTAMP ->
                Pair(
                    result.getTimestamp(col).toLocalDateTime().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME),
                    true
                )

            Types.TINYINT -> Pair(result.getByte(col).toString(), false)

            else -> Pair(result.getObject(col).toString(), true)
        }
    }


    /**
     * Reads [stream] to the end and renders its bytes as an uppercase hexadecimal
     * literal prefixed with `0x` (two digits per byte). The stream is closed afterwards.
     *
     * @param stream binary input, or `null`
     * @return `"NULL"` for a `null` stream, `"0x"` for an empty one, otherwise `0x` followed by the hex digits
     */
    @OptIn(ExperimentalStdlibApi::class)
    private fun bin2hex(stream: InputStream?): String {
        stream ?: return "NULL"

        val digits = "0123456789ABCDEF"
        val chunk = ByteArray(1024) // read in fixed-size chunks
        val hex = StringBuilder("0x")

        BufferedInputStream(stream, chunk.size).use { input ->
            while (true) {
                val count = input.read(chunk)
                if (count == -1) {
                    break
                }
                for (pos in 0 until count) {
                    // Mask to an unsigned 0..255 value, then emit the high and low nibble.
                    val octet = chunk[pos].toInt() and 0xFF
                    hex.append(digits[octet ushr 4]).append(digits[octet and 0x0F])
                }
            }
        }

        return hex.toString()
    }

    /** Reports whether [char] is an ASCII hexadecimal digit: `0-9`, `A-F` or `a-f`. */
    private fun isHexDigit(char: Char): Boolean = when (char) {
        in '0'..'9', in 'A'..'F', in 'a'..'f' -> true
        else -> false
    }

    /**
     * Decodes a hexadecimal string, optionally prefixed with `0x`, into bytes.
     *
     * A blank string or a bare `0x` decodes to an empty array. An odd number of
     * digits is left-padded with one `0`, so `"0xABC"` decodes like `"0x0ABC"`.
     *
     * @param hexString the hexadecimal text to decode
     * @return the decoded bytes, most significant digit first
     * @throws IllegalArgumentException if any character after the optional prefix is not a hex digit
     */
    private fun hex2bin(hexString: String): ByteArray {
        // Nothing to decode.
        if (hexString.isBlank() || hexString == "0x") {
            return ByteArray(0)
        }

        // Drop the optional prefix, then validate every remaining character.
        val digits = hexString.removePrefix("0x")
        require(digits.all(::isHexDigit)) { "输入的字符串不是合法的十六进制格式" }

        // Make the digit count even so the digits pair up into whole bytes.
        val padded = if (digits.length % 2 == 0) digits else "0$digits"

        return ByteArray(padded.length / 2) { slot ->
            val high = hexToByte(padded[slot * 2]).toInt()
            val low = hexToByte(padded[slot * 2 + 1]).toInt()
            ((high shl 4) or low).toByte()
        }
    }

    /**
     * Converts a single hexadecimal digit to its numeric value (0..15).
     *
     * @param hexChar an ASCII hex digit, upper or lower case
     * @return the digit's value as a [Byte]
     * @throws IllegalArgumentException if [hexChar] is not a hex digit
     */
    private fun hexToByte(hexChar: Char): Byte {
        val nibble = when {
            hexChar in '0'..'9' -> hexChar - '0'
            hexChar in 'A'..'F' -> hexChar - 'A' + 10
            hexChar in 'a'..'f' -> hexChar - 'a' + 10
            else -> throw IllegalArgumentException("Invalid hexadecimal character: $hexChar")
        }
        return nibble.toByte()
    }

    /**
     * Maps a JDBC type name (case-insensitive, e.g. `"varchar"`, `"BIGINT"`) to its
     * `java.sql.Types` constant.
     *
     * The name is upper-cased with the locale-invariant [String.uppercase], so the
     * lookup does not depend on the JVM default locale (a Turkish locale would
     * otherwise turn `"bigint"` into `"BİGİNT"` and miss every name containing `i`).
     *
     * @param dbType the type name to look up
     * @return the matching `Types` constant, or `Types.NULL` when the name is not recognised
     */
    private fun getSqlTypeByDbType(dbType: String): Int = when (dbType.uppercase()) {
        "ARRAY" -> Types.ARRAY
        "BIT" -> Types.BIT
        "BLOB" -> Types.BLOB
        "BIGINT" -> Types.BIGINT
        "BINARY" -> Types.BINARY
        "BOOLEAN" -> Types.BOOLEAN
        "CHAR" -> Types.CHAR
        "CLOB" -> Types.CLOB
        "DATE" -> Types.DATE
        "DOUBLE" -> Types.DOUBLE
        "DECIMAL" -> Types.DECIMAL
        "DATALINK" -> Types.DATALINK
        "FLOAT" -> Types.FLOAT
        "INTEGER" -> Types.INTEGER
        "LONGNVARCHAR" -> Types.LONGNVARCHAR
        "LONGVARCHAR" -> Types.LONGVARCHAR
        "LONGVARBINARY" -> Types.LONGVARBINARY
        "NCHAR" -> Types.NCHAR
        "NVARCHAR" -> Types.NVARCHAR
        "NUMERIC" -> Types.NUMERIC
        "NCLOB" -> Types.NCLOB
        "SMALLINT" -> Types.SMALLINT
        "SQLXML" -> Types.SQLXML
        "STRUCT" -> Types.STRUCT
        "TIME" -> Types.TIME
        "TIMESTAMP" -> Types.TIMESTAMP
        "TINYINT" -> Types.TINYINT
        "TIME_WITH_TIMEZONE" -> Types.TIME_WITH_TIMEZONE
        "TIMESTAMP_WITH_TIMEZONE" -> Types.TIMESTAMP_WITH_TIMEZONE
        "VARCHAR" -> Types.VARCHAR
        "VARBINARY" -> Types.VARBINARY
        else -> Types.NULL
    }
}